// Copyright 2022 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_SYSTOLIC_ARRAY_H
#define THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_SYSTOLIC_ARRAY_H

#include <mc_connections.h>
#include <systemc.h>

#include "ProcessingElement.h"
#include "Tieoff.h"
#include "src/AccelTypes.h"

template <typename IDTYPE, typename WDTYPE, typename ODTYPE, int NROWS,
          int NCOLS>
SC_MODULE(SystolicArray) {
 private:
  Connections::Combinational<IDTYPE> inputConnection[NROWS][NCOLS];
  Connections::Combinational<ODTYPE> psumConnection[NROWS - 1][NCOLS];
  sc_signal<WDTYPE> weightConnection[NROWS + 1][NCOLS];
  sc_signal<bool> weightValid;
  Connections::Combinational<ac_int<1, false> > weightSwap[NROWS * NCOLS - 1];
  Connections::Combinational<ac_int<1, false> > weightSwapFinal;

 public:
  sc_in<bool> CCS_INIT_S1(clk);
  sc_in<bool> CCS_INIT_S1(rstn);

  Connections::In<IDTYPE> inputs[NROWS];
  Connections::In<ac_int<1, false> > swapWeights[NROWS];
  Connections::In<Pack1D<IDTYPE, NCOLS> > CCS_INIT_S1(weights);
  Connections::In<ODTYPE> psums[NCOLS];
  Connections::Out<ODTYPE> outputs[NCOLS];
  Connections::SyncOut CCS_INIT_S1(weightSwapDone);

  SC_CTOR(SystolicArray) {
    ProcessingElement<IDTYPE, WDTYPE, ODTYPE> *pe[NROWS * NCOLS];
    for (int i = 0; i < NROWS; i++) {
      for (int j = 0; j < NCOLS; j++) {
        pe[i * NCOLS + j] = new ProcessingElement<IDTYPE, WDTYPE, ODTYPE>(
            sc_gen_unique_name("pe_inst"));
        pe[i * NCOLS + j]->clk(clk);
        pe[i * NCOLS + j]->rstn(rstn);

        if (j == 0) {
          pe[i * NCOLS + j]->inputIn(inputs[i]);
        } else {
          pe[i * NCOLS + j]->inputIn(inputConnection[i][j - 1]);
        }

        pe[i * NCOLS + j]->weightIn(weightConnection[i][j]);
        pe[i * NCOLS + j]->weightValid(weightValid);

        if (i == 0) {
          pe[i * NCOLS + j]->psumIn(psums[j]);
        } else {
          pe[i * NCOLS + j]->psumIn(psumConnection[i - 1][j]);
        }

        if (j == 0) {
          pe[i * NCOLS + j]->weightSwapIn(swapWeights[i]);
        } else {
          pe[i * NCOLS + j]->weightSwapIn(weightSwap[i * NCOLS + j - 1]);
        }

        if (i == NROWS - 1 && j == NCOLS - 1) {
          pe[i * NCOLS + j]->weightSwapOut(weightSwapFinal);
        } else {
          pe[i * NCOLS + j]->weightSwapOut(weightSwap[i * NCOLS + j]);
        }

        pe[i * NCOLS + j]->inputOut(inputConnection[i][j]);
        pe[i * NCOLS + j]->weightOut(weightConnection[i + 1][j]);
        if (i == NROWS - 1) {
          pe[i * NCOLS + j]->psumOut(outputs[j]);
        } else {
          pe[i * NCOLS + j]->psumOut(psumConnection[i][j]);
        }
      }
    }

    SC_THREAD(sendWeights);
    sensitive << clk.pos();
    async_reset_signal_is(rstn, false);

    SC_THREAD(checkSwapDone);
    sensitive << clk.pos();
    async_reset_signal_is(rstn, false);

    // Tie off unused Connections
    Tieoff<IDTYPE> *inputConnectionTieoff[NROWS];
    for (int i = 0; i < NROWS; i++) {
      inputConnectionTieoff[i] =
          new Tieoff<IDTYPE>(sc_gen_unique_name("tieoff"));
      inputConnectionTieoff[i]->in(inputConnection[i][NCOLS - 1]);
    }

    Tieoff<ac_int<1, false> > *weightSwapTieoff[NROWS - 1];
    for (int i = 0; i < NROWS - 1; i++) {
      weightSwapTieoff[i] =
          new Tieoff<ac_int<1, false> >(sc_gen_unique_name("tieoff"));
      weightSwapTieoff[i]->in(weightSwap[i * NCOLS + NCOLS - 1]);
    }
  }

  void sendWeights() {
    for (int j = 0; j < NCOLS; j++) {
      weightConnection[0][j].write(WDTYPE());
    }
    weightValid.write(false);
    weights.Reset();

    wait();

#pragma hls_pipeline_init_interval 1
#pragma hls_pipeline_stall_mode flush
    while (true) {
      Pack1D<WDTYPE, NCOLS> arrayWeights;
      if (weights.PopNB(arrayWeights)) {
        weightValid.write(true);

      } else {
        weightValid.write(false);
      }

#pragma hls_unroll yes
      for (int i = 0; i < NCOLS; i++) {
        weightConnection[0][i].write(arrayWeights[i]);
      }

      wait();
    }
  }

  void checkSwapDone() {
    weightSwapFinal.ResetRead();
    weightSwapDone.Reset();

    wait();

#pragma hls_pipeline_init_interval 1
#pragma hls_pipeline_stall_mode flush
    while (true) {
      ac_int<1, false> swap = weightSwapFinal.Pop();
      if (swap) {
        weightSwapDone.SyncPush();
      }
    }
  }
};

#endif  // THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_SYSTOLIC_ARRAY_H
